Llama架构详解
(徒手搓LLM)逐行代码从0构造一个LLM——LlaMa篇 - 知乎
Copy of LLM学习-从0构建一个自己的LLM .ipynb - Colab
Attention
输入参数
hidden_states
:- 输入的隐藏状态张量,形状为
(batch_size, sequence_length, hidden_size)
。 - 每个 token 在特定层中的表示。
- 输入的隐藏状态张量,形状为
attention_mask
(可选):- 掩码,用于指示哪些位置应该被注意力机制忽略(如填充位置)。
position_ids
(可选):- 用于计算旋转位置编码 (RoPE) 的位置索引。
past_key_value
(可选):- 用于缓存先前计算的
key
和value
,以支持增量推理。
- 用于缓存先前计算的
output_attentions
:- 是否输出注意力权重。
use_cache
:- 是否启用缓存机制。
cache_position
(可选):- 缓存中与位置相关的参数,用于与 RoPE 结合。
position_embeddings
(可选):- 外部提供的旋转位置编码的 cos 和 sin 值。
- `\kwargs`**:
- 允许传递其他附加参数。
代码分解
1. 计算 Query/Key/Value 矢量
使用线性投影从
hidden_states
中计算 Query、Key 和 Value:query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states)
如果启用了
pretraining_tp
(多张量并行,用于加速预训练),线性投影会切分成多个张量分别计算,并在最后拼接。
2. Reshape Query/Key/Value
将投影后的结果重塑为
(batch_size, num_heads, seq_length, head_dim)
,并交换维度:
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
3. 应用 RoPE(旋转位置编码)
如果
position_embeddings
未提供,会根据position_ids
动态计算cos
和sin
:cos, sin = self.rotary_emb(value_states, position_ids)
然后通过
apply_rotary_pos_emb
将位置编码融合到 Query 和 Key 上。
4. 增量推理缓存
如果
past_key_value
不为空,表示是增量推理模式,缓存会被更新:
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
5. 扩展 Key/Value(适配多头组)
如果
num_key_value_groups
小于
num_heads
,会通过重复 Key/Value 矢量来适配:
key_states = repeat_kv(key_states, self.num_key_value_groups)
6. 计算注意力权重
使用 Scaled Dot-Product 注意力公式:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
如果存在
attention_mask
,会加上掩码以忽略无效位置:attn_weights = attn_weights + causal_mask
使用 softmax 归一化,并应用 dropout:
attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
7. 计算注意力输出
将注意力权重与
value_states
相乘得到输出:attn_output = torch.matmul(attn_weights, value_states)
检查输出形状是否正确,并调整维度:
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1)
8. 输出投影
将注意力输出通过线性层:
attn_output = self.o_proj(attn_output)
9. 返回结果
- 最终返回:
- 注意力输出
attn_output
- 注意力权重
attn_weights
(如果output_attentions=True
) - 更新后的缓存
past_key_value
(如果使用了缓存)
- 注意力输出
总结
这段代码的核心功能是实现一个高效、灵活的注意力机制,支持以下特性:
- 标准多头注意力:通过 Query/Key/Value 计算。
- 旋转位置编码 (RoPE):改进的位置编码方案。
- 增量推理:缓存机制减少重复计算。
- 并行化优化:支持多张量并行。
这是一个高度优化且兼容的注意力模块,适用于大规模 Transformer 模型(如 GPT 系列)。
举个例子,对于LlamaSdpaAttention。
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
"""
算出qkv后,q的维度直接就是多头注意力的attention size,kv则是kv size
q [batch size, attention size, num tokens,hidden dim]
kv [batch size, kv size, num tokens,hidden dim]
"""
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
"""
在存入kv cache后,kv再变化为attention size
qkv [batch size, attention size, num tokens,hidden dim]
"""
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
"""
attn ouput的最初结果和qkv一致 [batch size, attention size, num tokens,hidden dim]
"""
if(hidden_states.shape[1]>1):
print(f"{self.layer_idx}: {attn_output.shape}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
"""
attn ouput再重构为 [batch size, num tokens, attention size * hidden dim]
"""
if(hidden_states.shape[1]>1):
print(f"{self.layer_idx}: {attn_output.shape}")
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value